# Hardware requested for each node of the training job.
resources:
  # Accelerator type and count in SkyPilot's `<name>:<count>` form.
  accelerators: A100-80GB:8
  # Run on spot (preemptible) instances; the run section restores the latest
  # checkpoint from the bucket so training can resume after a preemption.
  use_spot: true
  # Boot disk size (GB per the SkyPilot task spec) and performance tier.
  disk_size: 1000
  disk_tier: best

# Number of nodes to launch. The run section derives the torchrun node count
# from $SKYPILOT_NODE_IPS, so it follows this value.
num_nodes: 1

file_mounts:
  # Training data; the run section passes this path to --data_path.
  /data/mydata.json: ./dummy.json
  # Bucket that receives checkpoints and final artifacts under
  # /artifacts/chatbot/<MODEL_SIZE>b/skypilot-chatbot.
  # Change `name` to your own bucket name.
  /artifacts:
    mode: MOUNT
    name: YOUR_OWN_BUCKET_NAME

# Local directory synced to the cluster as the working directory. The run
# section calls ../scripts/sync_local_checkpoint.sh from inside FastChat/, so
# this directory is expected to contain scripts/sync_local_checkpoint.sh
# (and ./dummy.json, referenced by file_mounts).
workdir: .

setup: |
  # Build the `chatbot` conda environment: PyTorch (CUDA 11.6 wheels), a pinned
  # transformers commit with LLaMA support, and FastChat. Each step is written
  # so that running setup again on an already-provisioned machine is harmless.

  # Reuse the environment if it exists; create it only when activation fails.
  if ! conda activate chatbot; then
    conda create -n chatbot python=3.10 -y
    conda activate chatbot
  fi

  # Install pytorch
  pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116

  # Install huggingface with the LLaMA commit.
  # Clone only when the checkout is missing (git clone fails on an existing
  # directory), and run the install in a subshell chained with && so that
  # `pip install .` never runs from the wrong directory if `cd` fails.
  if [ ! -d transformers ]; then
    git clone https://github.com/huggingface/transformers.git
  fi
  (
    cd transformers &&
      git checkout 41a2f3529c6b56866c317031375ffd3e7b8bea01 &&
      pip install .
  )

  # Install fastchat (editable, so the run section can launch its train scripts
  # from the FastChat/ checkout).
  if [ ! -d FastChat ]; then
    git clone https://github.com/lm-sys/FastChat.git
  fi
  (cd FastChat && pip install -e .)

  # flash-attn is optional. Default to 0 and quote the expansion so an unset
  # or empty USE_FLASH_ATTN does not turn the test into a syntax error.
  if [ "${USE_FLASH_ATTN:-0}" -eq 1 ]; then
    pip install flash-attn
  fi

run: |
  # Fine-tune huggyllama/llama-${MODEL_SIZE}b with FastChat under torchrun.
  # Checkpoints are written to local disk and mirrored to the mounted bucket by
  # a background script, so a relaunched job can pick up the latest checkpoint.
  # The script exits with torchrun's exit code.
  cd FastChat
  conda activate chatbot

  # Pick the training entry point; train_mem.py is used when flash attention
  # is enabled.
  USE_FLASH_ATTN=${USE_FLASH_ATTN:-0}
  if [ "$USE_FLASH_ATTN" -eq 1 ]; then
    TRAIN_SCRIPT=fastchat/train/train_mem.py
  else
    TRAIN_SCRIPT=fastchat/train/train.py
  fi

  # Cluster topology: one IP per line in SKYPILOT_NODE_IPS; the first node
  # serves as the torchrun master.
  NUM_NODES=$(echo "$SKYPILOT_NODE_IPS" | wc -l)
  HOST_ADDR=$(echo "$SKYPILOT_NODE_IPS" | head -n1)

  # Batch sizing (integer arithmetic). The per-device batch shrinks as SEQ_LEN
  # grows, and gradient accumulation makes up the difference across all GPUs.
  PER_DEVICE_BATCH_SIZE=$((2048 * $GC_SCALE / $SEQ_LEN))
  GRAD_ACCUM_STEPS=$((128 * 512 / $SEQ_LEN / $PER_DEVICE_BATCH_SIZE / $NUM_NODES / $SKYPILOT_NUM_GPUS_PER_NODE))

  # Do the periodic syncing manually, to avoid the degradation of
  # the training for saving checkpoints.
  LOCAL_CKPT_PATH=~/.checkpoints
  CKPT_PATH=/artifacts/chatbot/${MODEL_SIZE}b/skypilot-chatbot
  mkdir -p "$LOCAL_CKPT_PATH" "$CKPT_PATH"

  # Restore the newest checkpoint from the bucket, if any. Entries are ordered
  # by the numeric part after the first '-'. On a first run nothing matches,
  # so the restore is skipped instead of rsync-ing the whole bucket directory.
  last_ckpt=$(ls "$CKPT_PATH" | grep -E '[0-9]+' | sort -t'-' -k1,1 -k2,2n | tail -1)
  if [ -n "$last_ckpt" ]; then
    mkdir -p "$LOCAL_CKPT_PATH/$last_ckpt"
    gsutil -m rsync -r "$CKPT_PATH/$last_ckpt/" "$LOCAL_CKPT_PATH/$last_ckpt"
  fi

  # Mirror local checkpoints to the bucket in the background during training.
  bash ../scripts/sync_local_checkpoint.sh "$LOCAL_CKPT_PATH" "$CKPT_PATH" > sync.log 2>&1 &

  # Turn off wandb if no api key is provided.
  # -z with a quoted expansion handles the empty/unset case (an unquoted
  # `[ $VAR == "" ]` is a syntax error when VAR is empty), and the variable is
  # exported so the torchrun child processes actually see it.
  if [ -z "${WANDB_API_KEY:-}" ]; then
    export WANDB_MODE="offline"
  fi

  torchrun \
    --nnodes="$NUM_NODES" \
    --nproc_per_node="$SKYPILOT_NUM_GPUS_PER_NODE" \
    --master_port=12375 \
    --master_addr="$HOST_ADDR" \
    --node_rank="${SKYPILOT_NODE_RANK}" \
    "$TRAIN_SCRIPT" \
    --model_name_or_path "huggyllama/llama-${MODEL_SIZE}b" \
    --data_path /data/mydata.json \
    --bf16 True \
    --output_dir "$LOCAL_CKPT_PATH" \
    --num_train_epochs 3 \
    --per_device_train_batch_size "$PER_DEVICE_BATCH_SIZE" \
    --per_device_eval_batch_size "$PER_DEVICE_BATCH_SIZE" \
    --gradient_accumulation_steps "$GRAD_ACCUM_STEPS" \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 600 \
    --save_total_limit 3 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --tf32 True \
    --model_max_length "${SEQ_LEN}" \
    --run_name "$SKYPILOT_JOB_ID" \
    --gradient_checkpointing True \
    --lazy_preprocess True

  # Capture torchrun's status before any other command overwrites $?.
  returncode=$?
  # Sync any files not in the checkpoint-* folders
  gsutil -m rsync -r -x 'checkpoint-*' "$LOCAL_CKPT_PATH/" "$CKPT_PATH/"
  exit $returncode


# Defaults for the environment variables read by the setup and run sections.
envs:
  # Weights & Biases API key; leave empty to train without a W&B account
  # (the run section is meant to fall back to offline mode in that case).
  WANDB_API_KEY: ""
  # 1 installs flash-attn and trains with fastchat/train/train_mem.py;
  # 0 trains with fastchat/train/train.py.
  USE_FLASH_ATTN: 0
  # LLaMA size in billions of parameters; selects huggyllama/llama-<N>b and
  # the checkpoint directory in the bucket.
  MODEL_SIZE: 7
  # Passed as --model_max_length; also the divisor in the batch-size arithmetic.
  SEQ_LEN: 2048
  # Multiplier for the per-device batch size: 2048 * GC_SCALE / SEQ_LEN.
  GC_SCALE: 4
